
/**
  ******************************************************************************
  * Copyright 2021 The grapilot Authors. All Rights Reserved.
  * 
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  * 
  * http://www.apache.org/licenses/LICENSE-2.0
  * 
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  * 
  * @file       gp_matrix_alg.c
  * @author     baiyang
  * @date       2021-12-12
  ******************************************************************************
  */

/*----------------------------------include-----------------------------------*/
#include "gp_matrix_alg.h"

#include <rtthread.h>

#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <float.h>
/*-----------------------------------macro------------------------------------*/

/*----------------------------------typedef-----------------------------------*/

/*---------------------------------prototype----------------------------------*/
static bool inverse3x3(const float* m, float* invOut);
static bool inverse4x4(const float* m,float* invOut);
static bool inverse9x9(const float* A, float* inv);
static void mat_LU_decompose9x9(const float* A, float* L, float* U, float *P, uint16_t n);
static bool mat_inverseN(const float* A, float* inv, uint16_t n);
static void mat_LU_decompose(const float* A, float* L, float* U, float *P, uint16_t n);
static void mat_pivot(const float* A, float* pivot, uint16_t n);
static void mat_forward_sub(const float *L, float *out, uint16_t n);
static void mat_back_sub(const float *U, float *out, uint16_t n);
static float* matrix_multiply(const float *A, const float *B, uint16_t n);
static inline void swap(float *a, float *b);
/*----------------------------------variable----------------------------------*/

/*-------------------------------------os-------------------------------------*/

/*----------------------------------function----------------------------------*/
/*
 *    Inverts a dim x dim matrix stored as a flat float array.
 *    Dimensions 3, 4 and 9 are routed to their dedicated routines, every
 *    other dimension goes through the generic LU-decomposition path.
 *
 *    @param     x,           input matrix (dim*dim floats)
 *    @param     y,           output inverted matrix (dim*dim floats)
 *    @param     dim,         dimension of the square matrix
 *    @returns                result of the selected inversion routine
 *                            (true = inversion successful)
 */
bool mat_inverse(const float* x, float *y, uint16_t dim)
{
    if (dim == 3) {
        return inverse3x3(x, y);
    }

    if (dim == 4) {
        return inverse4x4(x, y);
    }

    if (dim == 9) {
        return inverse9x9(x, y);
    }

    // no specialised routine available for this size
    return mat_inverseN(x, y, dim);
}

/*
 *    fast matrix inverse code only for 3x3 square matrix
 *    (adjugate divided by determinant)
 *
 *    The result is built in a local buffer and copied out only on success,
 *    so invOut is left untouched when false is returned.
 *
 *    @param     m,           input 3x3 matrix (9 floats)
 *    @param     invOut,      Output inverted 3x3 matrix (9 floats)
 *    @returns                false = determinant is zero, NaN or infinite,
 *                            true = matrix inversion successful
 */
static bool inverse3x3(const float* m, float* invOut)
{
    float inv[9];

    // determinant by cofactor expansion along the first row
    float  det = m[0] * (m[4] * m[8] - m[7] * m[5]) -
                 m[1] * (m[3] * m[8] - m[5] * m[6]) +
                 m[2] * (m[3] * m[7] - m[4] * m[6]);

    // A NaN determinant (e.g. NaN in the input) must be rejected as well,
    // otherwise the caller is told the inversion succeeded while every
    // output element is NaN. The 9x9 and NxN paths already report NaN
    // results as failure.
    if (math_flt_zero(det) || isnan(det) || isinf(det)) {
        return false;
    }

    float invdet = 1 / det;

    // transposed cofactor matrix scaled by 1/det
    inv[0] = (m[4] * m[8] - m[7] * m[5]) * invdet;
    inv[1] = (m[2] * m[7] - m[1] * m[8]) * invdet;
    inv[2] = (m[1] * m[5] - m[2] * m[4]) * invdet;
    inv[3] = (m[5] * m[6] - m[3] * m[8]) * invdet;
    inv[4] = (m[0] * m[8] - m[2] * m[6]) * invdet;
    inv[5] = (m[3] * m[2] - m[0] * m[5]) * invdet;
    inv[6] = (m[3] * m[7] - m[6] * m[4]) * invdet;
    inv[7] = (m[6] * m[1] - m[0] * m[7]) * invdet;
    inv[8] = (m[0] * m[4] - m[3] * m[1]) * invdet;

    for(uint16_t i = 0; i < 9; i++){
        invOut[i] = inv[i];
    }

    return true;
}

/*
 *    fast matrix inverse code only for 4x4 square matrix copied from
 *    gluInvertMatrix implementation in opengl for 4x4 matrices.
 *
 *    The unscaled result is built in a local buffer; invOut is written only
 *    on success and is left untouched when false is returned.
 *
 *    @param     m,           input 4x4 matrix (16 floats)
 *    @param     invOut,      Output inverted 4x4 matrix (16 floats)
 *    @returns                false = determinant is zero, NaN or infinite,
 *                            true = matrix inversion successful
 */
static bool inverse4x4(const float* m,float *invOut)
{
    float inv[16], det;
    uint16_t i;

    // inv[] first receives the unscaled result; it is divided by det at the end
    inv[0] = m[5]  * m[10] * m[15] -
             m[5]  * m[11] * m[14] -
             m[9]  * m[6]  * m[15] +
             m[9]  * m[7]  * m[14] +
             m[13] * m[6]  * m[11] -
             m[13] * m[7]  * m[10];

    inv[4] = -m[4]  * m[10] * m[15] +
             m[4]  * m[11] * m[14] +
             m[8]  * m[6]  * m[15] -
             m[8]  * m[7]  * m[14] -
             m[12] * m[6]  * m[11] +
             m[12] * m[7]  * m[10];

    inv[8] = m[4]  * m[9] * m[15] -
             m[4]  * m[11] * m[13] -
             m[8]  * m[5] * m[15] +
             m[8]  * m[7] * m[13] +
             m[12] * m[5] * m[11] -
             m[12] * m[7] * m[9];

    inv[12] = -m[4]  * m[9] * m[14] +
              m[4]  * m[10] * m[13] +
              m[8]  * m[5] * m[14] -
              m[8]  * m[6] * m[13] -
              m[12] * m[5] * m[10] +
              m[12] * m[6] * m[9];

    inv[1] = -m[1]  * m[10] * m[15] +
             m[1]  * m[11] * m[14] +
             m[9]  * m[2] * m[15] -
             m[9]  * m[3] * m[14] -
             m[13] * m[2] * m[11] +
             m[13] * m[3] * m[10];

    inv[5] = m[0]  * m[10] * m[15] -
             m[0]  * m[11] * m[14] -
             m[8]  * m[2] * m[15] +
             m[8]  * m[3] * m[14] +
             m[12] * m[2] * m[11] -
             m[12] * m[3] * m[10];

    inv[9] = -m[0]  * m[9] * m[15] +
             m[0]  * m[11] * m[13] +
             m[8]  * m[1] * m[15] -
             m[8]  * m[3] * m[13] -
             m[12] * m[1] * m[11] +
             m[12] * m[3] * m[9];

    inv[13] = m[0]  * m[9] * m[14] -
              m[0]  * m[10] * m[13] -
              m[8]  * m[1] * m[14] +
              m[8]  * m[2] * m[13] +
              m[12] * m[1] * m[10] -
              m[12] * m[2] * m[9];

    inv[2] = m[1]  * m[6] * m[15] -
             m[1]  * m[7] * m[14] -
             m[5]  * m[2] * m[15] +
             m[5]  * m[3] * m[14] +
             m[13] * m[2] * m[7] -
             m[13] * m[3] * m[6];

    inv[6] = -m[0]  * m[6] * m[15] +
             m[0]  * m[7] * m[14] +
             m[4]  * m[2] * m[15] -
             m[4]  * m[3] * m[14] -
             m[12] * m[2] * m[7] +
             m[12] * m[3] * m[6];

    inv[10] = m[0]  * m[5] * m[15] -
              m[0]  * m[7] * m[13] -
              m[4]  * m[1] * m[15] +
              m[4]  * m[3] * m[13] +
              m[12] * m[1] * m[7] -
              m[12] * m[3] * m[5];

    inv[14] = -m[0]  * m[5] * m[14] +
              m[0]  * m[6] * m[13] +
              m[4]  * m[1] * m[14] -
              m[4]  * m[2] * m[13] -
              m[12] * m[1] * m[6] +
              m[12] * m[2] * m[5];

    inv[3] = -m[1] * m[6] * m[11] +
             m[1] * m[7] * m[10] +
             m[5] * m[2] * m[11] -
             m[5] * m[3] * m[10] -
             m[9] * m[2] * m[7] +
             m[9] * m[3] * m[6];

    inv[7] = m[0] * m[6] * m[11] -
             m[0] * m[7] * m[10] -
             m[4] * m[2] * m[11] +
             m[4] * m[3] * m[10] +
             m[8] * m[2] * m[7] -
             m[8] * m[3] * m[6];

    inv[11] = -m[0] * m[5] * m[11] +
              m[0] * m[7] * m[9] +
              m[4] * m[1] * m[11] -
              m[4] * m[3] * m[9] -
              m[8] * m[1] * m[7] +
              m[8] * m[3] * m[5];

    inv[15] = m[0] * m[5] * m[10] -
              m[0] * m[6] * m[9] -
              m[4] * m[1] * m[10] +
              m[4] * m[2] * m[9] +
              m[8] * m[1] * m[6] -
              m[8] * m[2] * m[5];

    // determinant from the first row of m and the first column of inv[]
    det = m[0] * inv[0] + m[1] * inv[4] + m[2] * inv[8] + m[3] * inv[12];

    // A NaN determinant must be rejected too: otherwise success is reported
    // while every output element is NaN. The 9x9 and NxN paths already
    // report NaN results as failure.
    if (math_flt_zero(det) || isnan(det) || isinf(det)){
        return false;
    }

    det = 1.0f / det;

    for (i = 0; i < 16; i++) {
        invOut[i] = inv[i] * det;
    }

    return true;
}

/*
 *    matrix inverse code for a 9x9 square matrix using LU decomposition,
 *    working entirely on local (stack) buffers.
 *    inv = inv(U)*inv(L)*P, where L and U are triangular matrices and P the pivot matrix
 *    ref: http://www.cl.cam.ac.uk/teaching/1314/NumMethods/supporting/mcmaster-kiruba-ludecomp.pdf
 *
 *    The computed matrix is copied to inv even when false is returned.
 *
 *    @param     A,           input 9x9 matrix (81 floats)
 *    @param     inv,         Output inverted 9x9 matrix (81 floats)
 *    @returns                false = result contains NaN or infinity,
 *                            true = matrix inversion successful
 */
static bool inverse9x9(const float* A, float* inv)
{
    enum { DIM = 9, CELLS = DIM * DIM };

    float L[CELLS], U[CELLS], P[CELLS];
    float L_inv[CELLS];
    float U_inv[CELLS];
    float inv_unpivoted[CELLS];
    float inv_pivoted[CELLS];

    mat_LU_decompose9x9(A, L, U, P, DIM);

    // inv(L): forward substitution, solving L*Y = I one column at a time
    rt_memset(L_inv, 0, sizeof(L_inv));
    for (int col = 0; col < DIM; col++) {
        L_inv[col*DIM + col] = 1.0f / L[col*DIM + col];
        for (int row = col + 1; row < DIM; row++) {
            for (int k = col; k < row; k++) {
                L_inv[row*DIM + col] -= L[row*DIM + k] * L_inv[k*DIM + col];
            }
            L_inv[row*DIM + col] /= L[row*DIM + row];
        }
    }

    // inv(U): backward substitution, solving U*Y = I from the last column up
    rt_memset(U_inv, 0, sizeof(U_inv));
    for (int col = DIM - 1; col >= 0; col--) {
        U_inv[col*DIM + col] = 1.0f / U[col*DIM + col];
        for (int row = col - 1; row >= 0; row--) {
            for (int k = col; k > row; k--) {
                U_inv[row*DIM + col] -= U[row*DIM + k] * U_inv[k*DIM + col];
            }
            U_inv[row*DIM + col] /= U[row*DIM + row];
        }
    }

    // inv_unpivoted = inv(U) * inv(L)
    rt_memset(inv_unpivoted, 0, sizeof(inv_unpivoted));
    for (int row = 0; row < DIM; row++) {
        for (int col = 0; col < DIM; col++) {
            for (int k = 0; k < DIM; k++) {
                inv_unpivoted[row*DIM + col] += U_inv[row*DIM + k] * L_inv[k*DIM + col];
            }
        }
    }

    // inv_pivoted = inv_unpivoted * P
    rt_memset(inv_pivoted, 0, sizeof(inv_pivoted));
    for (int row = 0; row < DIM; row++) {
        for (int col = 0; col < DIM; col++) {
            for (int k = 0; k < DIM; k++) {
                inv_pivoted[row*DIM + col] += inv_unpivoted[row*DIM + k] * P[k*DIM + col];
            }
        }
    }

    // a single NaN or infinite element marks the whole result as invalid
    bool ok = true;
    for (int idx = 0; idx < CELLS; idx++) {
        if (isnan(inv_pivoted[idx]) || isinf(inv_pivoted[idx])) {
            ok = false;
            break;
        }
    }

    rt_memcpy(inv, inv_pivoted, sizeof(inv_pivoted));

    return ok;
}

/*
 *    Decomposes the row-permuted matrix P*A into Lower and Upper triangular
 *    matrices L and U, where P is the pivot matrix produced by mat_pivot().
 *    The permuted matrix is held in a fixed 9x9 local buffer, so this
 *    variant is only meant to be called with n == 9.
 *    ref: http://rosettacode.org/wiki/LU_decomposition
 *
 *    @param     A,           input matrix (n*n floats)
 *    @param     L,           Output lower triangular matrix
 *    @param     U,           Output upper triangular matrix
 *    @param     P,           Output pivot matrix
 *    @param     n,           dimension of matrix
 */
static void mat_LU_decompose9x9(const float* A, float* L, float* U, float *P, uint16_t n)
{
    const size_t bytes = n*n*sizeof(float);
    float permuted[9*9];

    rt_memset(L, 0, bytes);
    rt_memset(U, 0, bytes);
    rt_memset(P, 0, bytes);
    mat_pivot(A, P, n);

    // permuted = P * A
    rt_memset(permuted, 0, sizeof(permuted));
    for (uint16_t row = 0; row < n; row++) {
        for (uint16_t col = 0; col < n; col++) {
            for (uint16_t k = 0; k < n; k++) {
                permuted[row*n + col] += P[row*n + k] * A[k*n + col];
            }
        }
    }

    for (uint16_t d = 0; d < n; d++) {
        L[d*n + d] = 1;
    }

    // Work column by column: the entries of U on and above the diagonal
    // are resolved first, then the entries of L on and below it, which
    // need the freshly computed diagonal element of U.
    for (uint16_t col = 0; col < n; col++) {
        for (uint16_t row = 0; row <= col; row++) {
            U[row*n + col] = permuted[row*n + col];
            for (uint16_t k = 0; k < row; k++) {
                U[row*n + col] -= L[row*n + k] * U[k*n + col];
            }
        }

        for (uint16_t row = col; row < n; row++) {
            L[row*n + col] = permuted[row*n + col];
            for (uint16_t k = 0; k < col; k++) {
                L[row*n + col] -= L[row*n + k] * U[k*n + col];
            }
            L[row*n + col] /= U[col*n + col];
        }
    }
}

/*
 *    matrix inverse code for any square matrix using LU decomposition
 *    inv = inv(U)*inv(L)*P, where L and U are triangular matrices and P the pivot matrix
 *    ref: http://www.cl.cam.ac.uk/teaching/1314/NumMethods/supporting/mcmaster-kiruba-ludecomp.pdf
 *
 *    All working matrices are taken from the heap and released before
 *    returning. When a working buffer cannot be allocated (or n is 0) the
 *    function returns false without touching inv; once the computation has
 *    run, the result is copied to inv even if it fails the sanity check.
 *
 *    @param     A,           input n x n matrix (n*n floats)
 *    @param     inv,         Output inverted n x n matrix (n*n floats)
 *    @param     n,           dimension of square matrix
 *    @returns                false = allocation failed or result contains
 *                            NaN/infinity, true = matrix inversion successful
 */
static bool mat_inverseN(const float* A, float* inv, uint16_t n)
{
    bool ret = false;
    float *L = NULL;
    float *U = NULL;
    float *P = NULL;
    float *L_inv = NULL;
    float *U_inv = NULL;
    float *inv_unpivoted = NULL;
    float *inv_pivoted = NULL;

    if (n == 0) {
        return false;
    }

    // Buffers hold n*n floats: the size must be given in bytes, not elements.
    const size_t count = (size_t)n * n;
    const size_t bytes = count * sizeof(float);
    if (bytes / sizeof(float) != count) {
        // size computation wrapped around
        return false;
    }

    L = (float *)rt_malloc(bytes);
    U = (float *)rt_malloc(bytes);
    P = (float *)rt_malloc(bytes);
    if (L == NULL || U == NULL || P == NULL) {
        goto cleanup;
    }
    mat_LU_decompose(A,L,U,P,n);

    L_inv = (float *)rt_malloc(bytes);
    U_inv = (float *)rt_malloc(bytes);
    if (L_inv == NULL || U_inv == NULL) {
        goto cleanup;
    }

    rt_memset(L_inv,0,bytes);
    mat_forward_sub(L,L_inv,n);

    rt_memset(U_inv,0,bytes);
    mat_back_sub(U,U_inv,n);

    // decomposed matrices no longer required
    rt_free(L);
    L = NULL;
    rt_free(U);
    U = NULL;

    inv_unpivoted = matrix_multiply(U_inv,L_inv,n);
    if (inv_unpivoted == NULL) {
        goto cleanup;
    }

    // triangular inverses no longer required, release them before the
    // next product allocates its result
    rt_free(U_inv);
    U_inv = NULL;
    rt_free(L_inv);
    L_inv = NULL;

    inv_pivoted = matrix_multiply(inv_unpivoted, P, n);
    if (inv_pivoted == NULL) {
        goto cleanup;
    }

    //check sanity of results
    ret = true;
    for(size_t i = 0; i < count; i++) {
        if(isnan(inv_pivoted[i]) || isinf(inv_pivoted[i])){
            ret = false;
            break;
        }
    }
    rt_memcpy(inv,inv_pivoted,bytes);

cleanup:
    //free memory (pointers already released above are NULL here)
    rt_free(inv_pivoted);
    rt_free(inv_unpivoted);
    rt_free(P);
    rt_free(U_inv);
    rt_free(L_inv);
    rt_free(U);
    rt_free(L);

    return ret;
}

/*
 *    Decomposes the row-permuted matrix P*A into Lower and Upper triangular
 *    matrices L and U, where P is the pivot matrix produced by mat_pivot().
 *    ref: http://rosettacode.org/wiki/LU_decomposition
 *
 *    If the temporary buffer for P*A cannot be allocated the function
 *    returns early with L and U left all-zero (P is still filled in).
 *
 *    @param     A,           input matrix (n*n floats)
 *    @param     L,           Output lower triangular matrix
 *    @param     U,           Output upper triangular matrix
 *    @param     P,           Output pivot matrix
 *    @param     n,           dimension of matrix
 */
static void mat_LU_decompose(const float* A, float* L, float* U, float *P, uint16_t n)
{
    const size_t bytes = (size_t)n * n * sizeof(float);

    rt_memset(L,0,bytes);
    rt_memset(U,0,bytes);
    rt_memset(P,0,bytes);
    mat_pivot(A,P,n);

    // APrime = P * A, owned by this function and freed before returning
    float *APrime = matrix_multiply(P,A,n);
    if (APrime == NULL) {
        // Out of memory: do not dereference the missing product.
        return;
    }

    for(uint16_t i = 0; i < n; i++) {
        L[i*n + i] = 1;
    }
    for(uint16_t i = 0; i < n; i++) {
        for(uint16_t j = 0; j < n; j++) {
            if(j <= i) {
                // entry of U on or above the diagonal in column i
                U[j*n + i] = APrime[j*n + i];
                for(uint16_t k = 0; k < j; k++) {
                    U[j*n + i] -= L[j*n + k] * U[k*n + i];
                }
            }
            if(j >= i) {
                // entry of L on or below the diagonal in column i,
                // scaled by the diagonal element of U computed above
                L[j*n + i] = APrime[j*n + i];
                for(uint16_t k = 0; k < i; k++) {
                    L[j*n + i] -= L[j*n + k] * U[k*n + i];
                }
                L[j*n + i] /= U[i*n + i];
            }
        }
    }
    rt_free(APrime);
}

/*
 *    Builds the pivot (row permutation) matrix for A.
 *    pivot starts as the identity; then, for every column, the row at or
 *    below the diagonal holding the largest absolute value of A in that
 *    column is located and the matching two rows of pivot are exchanged.
 *    The search always reads the original A.
 *
 *    @param     A,           input matrix (n*n floats)
 *    @param     pivot,       Output pivot matrix (n*n floats)
 *    @param     n,           dimension of square matrix
 */
static void mat_pivot(const float* A, float* pivot, uint16_t n)
{
    // identity matrix
    for (uint16_t row = 0; row < n; row++) {
        for (uint16_t col = 0; col < n; col++) {
            pivot[row*n + col] = (row == col) ? 1.0f : 0.0f;
        }
    }

    for (uint16_t col = 0; col < n; col++) {
        // largest magnitude in this column, searching from the diagonal down
        uint16_t best_row = col;
        for (uint16_t row = (uint16_t)(col + 1); row < n; row++) {
            if (fabsf(A[row*n + col]) > fabsf(A[best_row*n + col])) {
                best_row = row;
            }
        }

        if (best_row == col) {
            continue;
        }

        // exchange rows 'col' and 'best_row' of the pivot matrix
        for (uint16_t k = 0; k < n; k++) {
            swap(&pivot[col*n + k], &pivot[best_row*n + k]);
        }
    }
}

/*
 *    calculates matrix inverse of Lower triangular matrix using forward substitution
 *
 *    The output is accumulated in place, so out must be zero-filled by the
 *    caller before the call.
 *
 *    @param     L,           lower triangular matrix (n*n floats)
 *    @param     out,         Output inverted lower triangular matrix
 *    @param     n,           dimension of matrix
 */
static void mat_forward_sub(const float *L, float *out, uint16_t n)
{
    const int dim = n;

    // Solve L*Y = I, one column of Y at a time, top to bottom
    for (int col = 0; col < dim; col++) {
        out[col*dim + col] = 1.0f / L[col*dim + col];

        for (int row = col + 1; row < dim; row++) {
            float *cell = &out[row*dim + col];
            const float *l_row = &L[row*dim];

            for (int k = col; k < row; k++) {
                *cell -= l_row[k] * out[k*dim + col];
            }
            *cell /= l_row[row];
        }
    }
}

/*
 *    calculates matrix inverse of Upper triangular matrix using backward substitution
 *
 *    The output is accumulated in place, so out must be zero-filled by the
 *    caller before the call.
 *
 *    @param     U,           upper triangular matrix (n*n floats)
 *    @param     out,         Output inverted upper triangular matrix
 *    @param     n,           dimension of matrix
 */
static void mat_back_sub(const float *U, float *out, uint16_t n)
{
    const int dim = n;

    // Solve U*Y = I, one column of Y at a time, last column first
    for (int col = dim - 1; col >= 0; col--) {
        out[col*dim + col] = 1.0f / U[col*dim + col];

        for (int row = col - 1; row >= 0; row--) {
            float *cell = &out[row*dim + col];
            const float *u_row = &U[row*dim];

            for (int k = col; k > row; k--) {
                *cell -= u_row[k] * out[k*dim + col];
            }
            *cell /= u_row[row];
        }
    }
}

/*
 *    Does matrix multiplication of two square matrices
 *
 *    The result is allocated with rt_malloc(); the caller owns it and must
 *    release it with rt_free().
 *
 *    @param     A,           Matrix A (n*n floats)
 *    @param     B,           Matrix B (n*n floats)
 *    @param     n,           dimension of square matrices
 *    @returns                multiplied matrix i.e. A*B, or NULL when the
 *                            result buffer could not be allocated
 */
static float* matrix_multiply(const float *A, const float *B, uint16_t n)
{
    // The buffer holds n*n floats: the size must be given in bytes, not
    // in elements, otherwise the writes below run past the allocation.
    const size_t count = (size_t)n * n;
    const size_t bytes = count * sizeof(float);
    if (bytes / sizeof(float) != count) {
        // size computation wrapped around
        return NULL;
    }

    float* ret = (float *)rt_malloc(bytes);
    if (ret == NULL) {
        return NULL;
    }
    rt_memset(ret,0,bytes);

    for(uint16_t i = 0; i < n; i++) {
        for(uint16_t j = 0; j < n; j++) {
            for(uint16_t k = 0;k < n; k++) {
                ret[i*n + j] += A[i*n + k] * B[k*n + j];
            }
        }
    }
    return ret;
}

/*
 *    Exchanges the two float values pointed to by a and b
 */
static inline void swap(float *a, float *b)
{
    const float first = *a;

    *a = *b;
    *b = first;
}

/*------------------------------------test------------------------------------*/


